# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
# https://github.com/gepa-ai/gepa

from typing import Any, TypedDict

from gepa.adapters.generic_rag_adapter.evaluation_metrics import RAGEvaluationMetrics
from gepa.adapters.generic_rag_adapter.rag_pipeline import RAGPipeline
from gepa.adapters.generic_rag_adapter.vector_store_interface import VectorStoreInterface
from gepa.core.adapter import EvaluationBatch, GEPAAdapter


class RAGDataInst(TypedDict):
    """
    Data instance for RAG evaluation and optimization.

    This TypedDict defines the structure for training and validation examples
    used in RAG system optimization with GEPA.

    Attributes:
        query (str): User query or question to be answered.
        ground_truth_answer (str): Expected/correct answer for evaluation.
        relevant_doc_ids (list[str]): Document IDs that should ideally be
            retrieved for this query (used for retrieval evaluation).
        metadata (dict[str, Any]): Additional context, tags, or configuration
            specific to this example (e.g., difficulty level, category).

    Example:
        .. code-block:: python

            data_inst = RAGDataInst(
                query="What is machine learning?",
                ground_truth_answer="Machine learning is a subset of AI...",
                relevant_doc_ids=["doc_001", "doc_042"],
                metadata={"category": "AI", "difficulty": "beginner"}
            )
    """

    query: str
    ground_truth_answer: str
    relevant_doc_ids: list[str]
    metadata: dict[str, Any]


class RAGTrajectory(TypedDict):
    """
    Detailed trajectory capturing all RAG pipeline execution steps.

    This TypedDict captures the complete execution trace of the RAG pipeline,
    providing visibility into each step for analysis and optimization.

    Attributes:
        original_query (str): Original user query as provided.
        reformulated_query (str): Query after the reformulation step (if enabled).
        retrieved_docs (list[dict[str, Any]]): Documents retrieved from the vector
            store, with their content, metadata, and similarity scores.
        synthesized_context (str): Context after the document synthesis step.
        generated_answer (str): Final answer generated by the LLM.
        execution_metadata (dict[str, Any]): Pipeline execution metadata, including
            retrieval metrics, generation metrics, token counts, and performance data.

    Note:
        Trajectories are only captured when capture_traces=True is passed to
        the evaluate() method, as they can be memory-intensive for large batches.
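
    Example:
        A sketch of a captured trajectory (all field values are illustrative):

        .. code-block:: python

            trajectory = RAGTrajectory(
                original_query="What is ML?",
                reformulated_query="What is machine learning?",
                retrieved_docs=[{"content": "...", "score": 0.91}],
                synthesized_context="Machine learning is...",
                generated_answer="Machine learning is a subset of AI...",
                execution_metadata={"total_tokens": 450}
            )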
    """

    original_query: str
    reformulated_query: str
    retrieved_docs: list[dict[str, Any]]
    synthesized_context: str
    generated_answer: str
    execution_metadata: dict[str, Any]


class RAGOutput(TypedDict):
    """
    Final output from RAG system execution.

    This TypedDict represents the final result of RAG pipeline execution,
    containing both the generated answer and associated metadata.

    Attributes:
        final_answer (str): The generated answer from the RAG system.
        confidence_score (float): Estimated confidence in the answer (0.0 to 1.0),
            based on retrieval quality and generation metrics.
        retrieved_docs (list[dict[str, Any]]): Documents that were retrieved and
            used for answer generation.
        total_tokens (int): Estimated total token usage for the pipeline execution.

    Example:
        .. code-block:: python

            output = RAGOutput(
                final_answer="Machine learning is a method of data analysis...",
                confidence_score=0.87,
                retrieved_docs=[{"content": "...", "score": 0.9}],
                total_tokens=450
            )
    """

    final_answer: str
    confidence_score: float
    retrieved_docs: list[dict[str, Any]]
    total_tokens: int


class GenericRAGAdapter(GEPAAdapter[RAGDataInst, RAGTrajectory, RAGOutput]):
    """
    Generic GEPA adapter for RAG system optimization with pluggable vector stores.

    This adapter enables GEPA's evolutionary prompt optimization to work with any
    vector store implementation through the VectorStoreInterface. It provides
    comprehensive evaluation of both retrieval and generation quality.

    Optimizable Components:
        - Query reformulation prompts: Improve query understanding and reformulation
        - Context synthesis prompts: Optimize document combination and summarization
        - Answer generation prompts: Enhance final answer quality and formatting
        - Reranking criteria: Improve document relevance ordering

    Evaluation Metrics:
        - Retrieval Quality: Precision, recall, F1, mean reciprocal rank (MRR)
        - Generation Quality: Token F1, BLEU score, faithfulness, answer relevance
        - Combined Score: Weighted combination for overall system performance

    Vector Store Support:
        Works with any vector store implementing VectorStoreInterface, including
        ChromaDB, Weaviate, Qdrant, Pinecone, Milvus, and custom implementations.

    Example:
        .. code-block:: python

            from gepa.adapters.generic_rag_adapter import GenericRAGAdapter, ChromaVectorStore
            import gepa

            vector_store = ChromaVectorStore.create_local("./kb", "docs")
            adapter = GenericRAGAdapter(vector_store=vector_store, llm_model="gpt-4")

            result = gepa.optimize(
                seed_candidate={"answer_generation": "Answer based on context:"},
                trainset=train_data,
                valset=val_data,
                adapter=adapter,
                max_metric_calls=50
            )
            print(result.best_candidate)  # Optimized prompts
    """

    def __init__(
        self,
        vector_store: VectorStoreInterface,
        llm_model,
        embedding_model: str = "text-embedding-3-small",
        embedding_function=None,
        rag_config: dict[str, Any] | None = None,
        failure_score: float = 0.0,
    ):
        """
        Initialize the GenericRAGAdapter for RAG system optimization.

        Args:
            vector_store: Vector store implementation (ChromaDB, Weaviate, etc.)
                Must implement VectorStoreInterface for similarity search operations.
            llm_model: LLM client for text generation. Can be:
                - String model name (uses litellm for inference)
                - Callable that takes chat messages and returns response text
                  (see the sketch below the example)
                - Any object with a callable interface for LLM inference
            embedding_model: Model name for text embeddings (default: "text-embedding-3-small").
                Used when embedding_function is not provided.
            embedding_function: Optional custom embedding function that takes text
                and returns list[float] (see the sketch below the example). If None,
                uses default litellm embeddings.
            rag_config: RAG pipeline configuration dictionary. Keys include:
                - "retrieval_strategy": "similarity", "hybrid", or "vector"
                - "top_k": Number of documents to retrieve (default: 5)
                - "retrieval_weight": Weight for retrieval in combined score (default: 0.3)
                - "generation_weight": Weight for generation in combined score (default: 0.7)
                - "hybrid_alpha": Semantic vs keyword balance for hybrid search (default: 0.5)
                - "filters": Default metadata filters for retrieval
            failure_score: Score assigned when evaluation fails (default: 0.0)

        Example:
            .. code-block:: python

                vector_store = WeaviateVectorStore.create_local(collection_name="docs")
                adapter = GenericRAGAdapter(
                    vector_store=vector_store,
                    llm_model="gpt-4",
                    rag_config={
                        "retrieval_strategy": "hybrid",
                        "top_k": 5,
                        "hybrid_alpha": 0.7
                    }
                )
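
            A callable llm_model and a custom embedding_function only need to
            match the shapes described above. A minimal sketch using litellm
            (model names here are illustrative):

            .. code-block:: python

                import litellm

                def my_llm(messages: list[dict[str, str]]) -> str:
                    # Chat-completions style: forward the messages, return the text.
                    response = litellm.completion(model="gpt-4o-mini", messages=messages)
                    return response.choices[0].message.content

                def my_embedder(text: str) -> list[float]:
                    response = litellm.embedding(model="text-embedding-3-small", input=[text])
                    return response.data[0]["embedding"]

                adapter = GenericRAGAdapter(
                    vector_store=vector_store,
                    llm_model=my_llm,
                    embedding_function=my_embedder,
                )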
        """
        self.vector_store = vector_store
        self.rag_pipeline = RAGPipeline(
            vector_store=vector_store,
            llm_client=llm_model,
            embedding_model=embedding_model,
            embedding_function=embedding_function,
        )
        self.evaluator = RAGEvaluationMetrics()
        self.config = rag_config or self._default_config()
        self.failure_score = failure_score

    def evaluate(
        self,
        batch: list[RAGDataInst],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch[RAGTrajectory, RAGOutput]:
        """
        Evaluate RAG system performance on a batch of query-answer examples.

        This method runs the complete RAG pipeline on each example in the batch,
        evaluating both retrieval and generation quality using the provided
        prompt components.

        Args:
            batch: List of RAG evaluation examples, each containing:
                - query: Question to answer
                - ground_truth_answer: Expected correct answer
                - relevant_doc_ids: Documents that should be retrieved
                - metadata: Additional context for evaluation
            candidate: Dictionary mapping prompt component names to their text.
                Supported components:
                - "query_reformulation": Prompt for improving user queries
                - "context_synthesis": Prompt for combining retrieved documents
                - "answer_generation": Prompt for generating final answers
                - "reranking_criteria": Criteria for reordering retrieved documents
            capture_traces: If True, capture detailed execution trajectories
                for each example. Required for reflective dataset generation but
                increases memory usage.

        Returns:
            EvaluationBatch containing:
            - outputs: List of RAGOutput for each example
            - scores: List of combined quality scores (higher = better)
            - trajectories: List of detailed execution traces (if capture_traces=True)

        Raises:
            Exception: Individual example failures are caught and assigned failure_score.
                Only systemic failures (e.g., vector store unavailable) raise exceptions.

        Example:
            .. code-block:: python

                prompts = {
                    "answer_generation": "Answer the question based on this context:"
                }
                result = adapter.evaluate(
                    batch=validation_data,
                    candidate=prompts,
                    capture_traces=True
                )
                avg_score = sum(result.scores) / len(result.scores)
                print(f"Average RAG performance: {avg_score:.3f}")
        """
        outputs: list[RAGOutput] = []
        scores: list[float] = []
        trajectories: list[RAGTrajectory] | None = [] if capture_traces else None

        for data_inst in batch:
            try:
                # Execute RAG pipeline with candidate prompts
                rag_result = self.rag_pipeline.execute_rag(
                    query=data_inst["query"], prompts=candidate, config=self.config
                )

                # Evaluate retrieval quality
                retrieval_metrics = self.evaluator.evaluate_retrieval(
                    rag_result["retrieved_docs"], data_inst["relevant_doc_ids"]
                )

                # Evaluate generation quality
                generation_metrics = self.evaluator.evaluate_generation(
                    rag_result["generated_answer"], data_inst["ground_truth_answer"], rag_result["synthesized_context"]
                )

                # Calculate combined score
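                # (weighted blend of retrieval and generation metrics; default
                # weights 0.3 retrieval / 0.7 generation, see _default_config)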
                overall_score = self.evaluator.combined_rag_score(
                    retrieval_metrics,
                    generation_metrics,
                    retrieval_weight=self.config.get("retrieval_weight", 0.3),
                    generation_weight=self.config.get("generation_weight", 0.7),
                )

                # Prepare output
                output = RAGOutput(
                    final_answer=rag_result["generated_answer"],
                    confidence_score=generation_metrics.get("answer_confidence", 0.5),
                    retrieved_docs=rag_result["retrieved_docs"],
                    total_tokens=rag_result["metadata"]["total_tokens"],
                )

                outputs.append(output)
                scores.append(overall_score)

                # Capture trajectory if requested
                if capture_traces:
                    trajectory = RAGTrajectory(
                        original_query=rag_result["original_query"],
                        reformulated_query=rag_result["reformulated_query"],
                        retrieved_docs=rag_result["retrieved_docs"],
                        synthesized_context=rag_result["synthesized_context"],
                        generated_answer=rag_result["generated_answer"],
                        execution_metadata={
                            **rag_result["metadata"],
                            "retrieval_metrics": retrieval_metrics,
                            "generation_metrics": generation_metrics,
                            "overall_score": overall_score,
                        },
                    )
                    trajectories.append(trajectory)

            except Exception as e:
                # Handle individual example failure
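                # Score this example with failure_score rather than aborting the
                # whole batch; the error is surfaced in the output's final_answer.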
                error_output = RAGOutput(
                    final_answer=f"Error: {e!s}", confidence_score=0.0, retrieved_docs=[], total_tokens=0
                )

                outputs.append(error_output)
                scores.append(self.failure_score)

                if capture_traces:
                    error_trajectory = RAGTrajectory(
                        original_query=data_inst["query"],
                        reformulated_query=data_inst["query"],
                        retrieved_docs=[],
                        synthesized_context="",
                        generated_answer=f"Error: {e!s}",
                        execution_metadata={"error": str(e)},
                    )
                    trajectories.append(error_trajectory)

        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    def make_reflective_dataset(
        self,
        candidate: dict[str, str],
        eval_batch: EvaluationBatch[RAGTrajectory, RAGOutput],
        components_to_update: list[str],
    ) -> dict[str, list[dict[str, Any]]]:
        """
        Generate reflective dataset for evolutionary prompt optimization.

        This method analyzes the evaluation results and creates training examples
        that GEPA's proposer can use to improve the specified prompt components.
        Each component gets a tailored dataset with input-output pairs and feedback.

        Args:
            candidate: Current prompt components that were evaluated
            eval_batch: Evaluation results from evaluate() with capture_traces=True.
                Must contain trajectories for analysis.
            components_to_update: List of component names to generate improvement
                suggestions for. Must be subset of candidate.keys().

        Returns:
            Dictionary mapping component names to their reflective datasets.
            Each dataset is a list of examples with structure:
            - "Inputs": Input data for the component (query, docs, etc.)
            - "Generated Outputs": What the component currently produces
            - "Feedback": Analysis of performance and suggestions for improvement

        Example:
            .. code-block:: python

                reflective_data = adapter.make_reflective_dataset(
                    candidate=current_prompts,
                    eval_batch=evaluation_results,  # with trajectories
                    components_to_update=["answer_generation", "context_synthesis"]
                )
                print(reflective_data["answer_generation"][0]["Feedback"])
                # Output: "The generated answer lacks specific details from the context..."

        Note:
            This method requires eval_batch to have been created with
            capture_traces=True, otherwise trajectories will be None.
        """
        reflective_data: dict[str, list[dict[str, Any]]] = {}

        for component in components_to_update:
            component_examples = []

            # Process each trajectory to create examples for this component
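            # trajectories is None when evaluate() ran without capture_traces;
            # fall back to an empty list so no examples are produced.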
            for traj, output, score in zip(
                eval_batch.trajectories or [], eval_batch.outputs, eval_batch.scores, strict=False
            ):
                example = self._create_component_example(component, traj, output, score, candidate)
                if example:
                    component_examples.append(example)

            # Only include components that have examples
            if component_examples:
                reflective_data[component] = component_examples

        return reflective_data

    def _create_component_example(
        self, component_name: str, trajectory: RAGTrajectory, output: RAGOutput, score: float, candidate: dict[str, str]
    ) -> dict[str, Any] | None:
        """Create a reflective example for a specific component."""

        if component_name == "query_reformulation":
            return {
                "Inputs": {
                    "original_query": trajectory["original_query"],
                    "current_prompt": candidate.get(component_name, ""),
                },
                "Generated Outputs": trajectory["reformulated_query"],
                "Feedback": self._generate_query_reformulation_feedback(trajectory, score),
            }

        elif component_name == "context_synthesis":
            return {
                "Inputs": {
                    "query": trajectory["original_query"],
                    "retrieved_docs": [doc["content"] for doc in trajectory["retrieved_docs"]],
                    "current_prompt": candidate.get(component_name, ""),
                },
                "Generated Outputs": trajectory["synthesized_context"],
                "Feedback": self._generate_context_synthesis_feedback(trajectory, score),
            }

        elif component_name == "answer_generation":
            return {
                "Inputs": {
                    "query": trajectory["original_query"],
                    "context": trajectory["synthesized_context"],
                    "current_prompt": candidate.get(component_name, ""),
                },
                "Generated Outputs": trajectory["generated_answer"],
                "Feedback": self._generate_answer_generation_feedback(trajectory, output, score),
            }

        elif component_name == "reranking_criteria":
            return {
                "Inputs": {
                    "query": trajectory["original_query"],
                    "documents": [doc["content"] for doc in trajectory["retrieved_docs"]],
                    "current_criteria": candidate.get(component_name, ""),
                },
                "Generated Outputs": "Document ranking applied",
                "Feedback": self._generate_reranking_feedback(trajectory, score),
            }

        # Unknown component name: nothing to contribute to the reflective dataset.
        return None

    def _generate_query_reformulation_feedback(self, trajectory: RAGTrajectory, score: float) -> str:
        """Generate feedback for query reformulation component."""
        if score > 0.7:
            return f"Good query reformulation. The reformulated query '{trajectory['reformulated_query']}' helped retrieve relevant documents and generated a good answer."
        else:
            return f"The query reformulation from '{trajectory['original_query']}' to '{trajectory['reformulated_query']}' may not have improved retrieval. Consider making the reformulated query more specific or preserving key terms."

    def _generate_context_synthesis_feedback(self, trajectory: RAGTrajectory, score: float) -> str:
        """Generate feedback for context synthesis component."""
        if score > 0.7:
            return "Context synthesis worked well - the synthesized context effectively supported answer generation."
        else:
            return "Context synthesis could be improved. The synthesized context may not have highlighted the most relevant information or may have been too verbose/concise."

    def _generate_answer_generation_feedback(self, trajectory: RAGTrajectory, output: RAGOutput, score: float) -> str:
        """Generate feedback for answer generation component."""
        if score > 0.7:
            return f"Good answer generation. The generated answer '{trajectory['generated_answer']}' was accurate and well-supported by the context."
        else:
            return f"Answer generation needs improvement. The generated answer '{trajectory['generated_answer']}' may not be fully accurate or well-supported by the provided context."

    def _generate_reranking_feedback(self, trajectory: RAGTrajectory, score: float) -> str:
        """Generate feedback for reranking criteria component."""
        if score > 0.7:
            return "Document reranking appears to have helped surface more relevant documents for answer generation."
        else:
            return "Document reranking may not have improved relevance. Consider adjusting the criteria to better prioritize documents that contain the answer."

    def _default_config(self) -> dict[str, Any]:
        """
        Get default configuration for RAG pipeline.

        Returns:
            Dictionary with default RAG configuration parameters:
            - retrieval_strategy: "similarity" (semantic search)
            - top_k: 5 (number of documents to retrieve)
            - retrieval_weight: 0.3 (30% weight for retrieval metrics)
            - generation_weight: 0.7 (70% weight for generation metrics)
            - hybrid_alpha: 0.5 (balanced semantic/keyword for hybrid search)
            - filters: None (no metadata filtering by default)
        """
        return {
            "retrieval_strategy": "similarity",
            "top_k": 5,
            "retrieval_weight": 0.3,
            "generation_weight": 0.7,
            "hybrid_alpha": 0.5,
            "filters": None,
        }
